import json
import os
import re
import time
from typing import Optional

from openai import OpenAI
from tqdm import tqdm

# Shared OpenAI client used by single_chat_gpt_wrapper. Credentials come from
# the environment instead of being hard-coded in the source. The "" fallback
# keeps the module importable when OPENAI_API_KEY is unset (as the old empty
# placeholder did); real requests need a valid key. The organization is None
# when OPENAI_ORG_ID is unset.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", ""),
    organization=os.environ.get("OPENAI_ORG_ID"),
)
def single_chat_gpt_wrapper(
    messages: list,
    temperature: float = 0.7,
    *,
    model: str = "gpt-4-turbo-preview",
    max_retries: int = 2,
    retry_delay: float = 5.0,
) -> Optional[str]:
    """Send one chat-completion request and return the reply text.

    Uses the module-level ``client``. Any ``Exception`` raised by the request
    is printed and the request is retried after ``retry_delay`` seconds, for at
    most ``max_retries`` attempts in total. ``KeyboardInterrupt`` derives from
    ``BaseException`` rather than ``Exception``, so it is not caught here and
    still stops the script with its original traceback.

    Args:
        messages: Chat messages passed unchanged to
            ``client.chat.completions.create``.
        temperature: Sampling temperature forwarded to the API.
        model: Model name forwarded to the API.
        max_retries: Total number of attempts before giving up.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        ``response.choices[0].message.content`` of the first successful
        request, or ``None`` if every attempt raised.
    """
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            return response.choices[0].message.content
        except Exception as e:
            print(e)
            # Only back off when another attempt will actually follow.
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
    return None

def getProblem(exemplars: list, generate_num: int = 5) -> str:
    """Build the user prompt that asks the model for new coding queries.

    Args:
        exemplars: Few-shot examples; only each item's ``'query'`` value is
            used, wrapped in ``<query> ... </query>`` tags.
        generate_num: Number of queries the prompt asks the model to produce.

    Returns:
        The complete prompt as a single string.
    """
    parts = [
        f"Generate {generate_num} coding queries and tests in the format of:\n<query> ... </query>\n",
        "Do not include test information in the queries.\n",
        "Here are some examples:\n",
    ]
    parts.extend(f"<query> {exemplar['query']} </query>\n\n" for exemplar in exemplars)
    parts.append(
        "Again, generate the queries following the examples strictly, in the format of:\n<query> ... </query>\n"
    )
    return "".join(parts)

def save_data(res: Optional[str], path: str = "new_data/problem.jsonl") -> None:
    """Extract ``<query> ... </query>`` spans from a model reply and append them as JSONL.

    Each non-empty extracted query (whitespace-stripped) is appended to
    ``path`` as one ``{"query": ...}`` JSON object per line.

    Args:
        res: Raw model reply. ``None`` or ``""`` (e.g. a failed request) is
            ignored and nothing is written.
        path: JSONL file to append to.
    """
    if not res:
        return
    # Match opening and closing tags as explicit pairs, so a stray "query>"
    # elsewhere in the reply cannot shift the pairing of later queries.
    stripped = (m.strip() for m in re.findall(r"<query>(.*?)</query>", res, flags=re.DOTALL))
    problems = [p for p in stripped if p]
    with open(path, "a", encoding="utf-8") as f:
        for problem in problems:
            f.write(json.dumps({"query": problem}) + "\n")

def genProblem():
    """Ask the model for new coding queries, one request per batch of 15 exemplars.

    Reads the module-level ``problem_dicts`` and ``DataNum`` set in the
    ``__main__`` block. Batch ``i`` uses ``problem_dicts[i*15 : i*15 + 15]``
    as few-shot examples and asks for 100 queries. Each reply is printed, and
    its queries are appended via ``save_data``. Failed requests (``None``
    replies) are skipped.
    """
    batch_size = 15
    for i in tqdm(range(DataNum // batch_size)):
        # __main__ slices problem_dicts down to the working window before this
        # is called, so batches are indexed relative to that list. Offsetting
        # by TRAIN_IDX_START again would run past its end and send no exemplars.
        st = i * batch_size
        messages = [
            {"role": "system", "content": "Be a coding problem creator and create coding queries."},
            {"role": "user", "content": getProblem(problem_dicts[st:st + batch_size], generate_num=100)},
        ]
        response = single_chat_gpt_wrapper(messages, temperature=1)
        print(response)
        if response is None:
            continue
        save_data(response)

def save_code(problem: str, res: Optional[str], path: str = "new_data/full.jsonl") -> None:
    """Append one ``{"query", "code", "tests"}`` record built from a model reply.

    The last ``<code> ... </code>`` span and the last ``<test> ... </test>``
    span in ``res`` (whitespace-stripped) are paired with ``problem``. They are
    appended to ``path`` as a single JSON line, with ``tests`` as a one-element
    list.

    If the reply is missing or lacks a code or test span, nothing is written
    and a short note is printed.

    Args:
        problem: The coding query the reply answers.
        res: Raw model reply; ``None`` or ``""`` means the request failed.
        path: JSONL file to append to.
    """
    if not res:
        print("save_code: empty response, skipping")
        return
    codes = re.findall(r"<code>(.*?)</code>", res, flags=re.DOTALL)
    tests = re.findall(r"<test>(.*?)</test>", res, flags=re.DOTALL)
    if not codes or not tests:
        print("save_code: response has no <code>/<test> block, skipping")
        return
    record = {"query": problem, "code": codes[-1].strip(), "tests": [tests[-1].strip()]}
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def getCode(exemplars: list, query: str) -> str:
    """Build the user prompt asking the model for code and a test for ``query``.

    Args:
        exemplars: Few-shot examples; each item's ``'query'``, ``'code'`` and
            ``'test'`` values are rendered inside matching tags.
        query: The new coding query to solve; it is appended last.

    Returns:
        The complete prompt as a single string.
    """
    parts = [
        "Generate python codes and tests given coding queries in the format of:\n<code> ... </code>\n<test> ... </test>\n",
        "Only include one assertion expression in the test.\n",
        "Here are some examples:\n",
    ]
    parts.extend(
        f"<query> {ex['query']} </query>\n<code> {ex['code']} </code>\n<test> {ex['test']} </test>\n\n"
        for ex in exemplars
    )
    parts.append(
        "Again, generate the python code and test following the examples strictly, in the format of:\n<code> ... </code>\n<test> ... </test>\n"
    )
    parts.append(f"<query> {query} </query>\n")
    return "".join(parts)

if __name__ == "__main__":
    # Load the exemplar pool and keep only a DataNum-item window of it for
    # this run. genProblem and the loop below index relative to this window.
    with open("data/python/mbpp/ic_mbpp.json", encoding="utf-8") as f:
        problem_dicts = json.load(f)
    TRAIN_IDX_START = 511
    DataNum = 15
    problem_dicts = problem_dicts[TRAIN_IDX_START + 375:TRAIN_IDX_START + 375 + DataNum]
    # Stage 1 (generate new queries into new_data/problem.jsonl); uncomment to run:
    # genProblem()

    # Stage 2: generate code + test for the most recent 100 queries.
    with open("./new_data/problem.jsonl", "r", encoding="utf-8") as json_file:
        json_list = [line for line in json_file if line.strip()]

    # Queries that already have generated code. A missing file just means
    # nothing has been generated yet.
    p_set = set()
    if os.path.exists("./new_data/full.jsonl"):
        with open("./new_data/full.jsonl", "r", encoding="utf-8") as full_file:
            p_set = {json.loads(line)["query"] for line in full_file if line.strip()}

    recent = json_list[-100:]
    for j, json_str in tqdm(enumerate(recent), total=len(recent)):
        problem = json.loads(json_str)
        if problem["query"] in p_set:
            continue
        print(j)
        # problem_dicts is already the sliced window, so the exemplar index is
        # relative to it. Adding TRAIN_IDX_START here pointed past the end of
        # the window and silently produced a prompt with no examples.
        example_idx = (j // 100) * 15
        messages = [
            {"role": "system", "content": "Be a coding assistant, create python codes and tests given queries."},
            {"role": "user", "content": getCode(problem_dicts[example_idx:example_idx + 1], problem["query"])},
        ]
        response = single_chat_gpt_wrapper(messages, temperature=0)
        print(response)
        if response is None:
            continue
        save_code(problem["query"], response)
        # Don't regenerate a duplicate of this query later in the same run.
        p_set.add(problem["query"])